import csv
import time
import sys
import os
from unittest import case

from util.graphdb_base import GraphDBBase

class SfariImporter(GraphDBBase):
    """Import tab-separated SFARI export files into the graph database.

    Every public ``import_*`` method reads one tab-separated file with a
    header row and sends one Cypher statement per data row through
    ``self._driver`` (not defined in this class; presumably provided by
    ``GraphDBBase``).
    """

    # Rows written per transaction before an intermediate commit.
    _BATCH_SIZE = 1000

    # ``variation_type_id`` value -> relationship pattern merged between a
    # participant and a gene.
    _VARIATION_RELATIONSHIPS = {
        "1": "DUPLICATED {geneDosage: 2, inheritance: $inheritance}",
        "2": "DELETED {geneDosage: -1, inheritance: $inheritance}",
        "3": "UNKNOWN_DOSAGE {inheritance: $inheritance}",
        "4": "TRIPLICATED {geneDosage: 3, inheritance: $inheritance}",
    }

    # ``inheritance_id`` value -> value stored on the relationship.
    _INHERITANCES = {"1": "paternal", "2": "both", "3": "unknown", "4": "maternal"}

    # ``domain`` value -> extra node label added next to ``Go``.
    _GO_DOMAIN_LABELS = {
        "molecular_function": "MolecularFunction",
        "cellular_component": "CellularComponent",
        "biological_process": "BiologicalProcess",
    }

    def __init__(self, argv):
        """Pass this script's path and the command-line arguments to the base class."""
        super().__init__(command=__file__, argv=argv)

    def _import_rows(self, file, to_statement, final_label, progress_label=None,
                     constraints=(), batch_size=_BATCH_SIZE):
        """Run one Cypher statement per row of a tab-separated file.

        Args:
            file: Path of the tab-separated file; its first line is the header.
            to_statement: Callable taking the row dict and returning a
                ``(query, parameters)`` pair.
            final_label: Text printed after the total row count once the last
                transaction has been committed.
            progress_label: Text printed after the running row count at each
                intermediate commit.
            constraints: Statements passed to ``execute_without_exception``
                before any row is written.
            batch_size: Number of rows per transaction; ``None`` keeps every
                row in one single transaction.

        Raises:
            Exception: Any error raised while reading or writing a row is
                re-raised after the error, the row and the file line number
                have been printed.
        """
        # Read-only access is enough; newline='' is what the csv module
        # expects so that newlines embedded in quoted fields survive.
        with open(file, newline='') as tsv_file:
            reader = csv.DictReader(tsv_file, delimiter="\t")
            with self._driver.session() as session:
                for constraint in constraints:
                    self.execute_without_exception(constraint)

                tx = session.begin_transaction()
                pending = 0  # rows written since the last commit
                total = 0    # rows written overall
                row = None   # keeps the error report valid if the very first read fails
                try:
                    for row in reader:
                        query, parameters = to_statement(row)
                        tx.run(query, parameters)
                        pending += 1
                        total += 1

                        if batch_size is not None and pending == batch_size:
                            tx.commit()
                            print(total, progress_label)
                            pending = 0
                            tx = session.begin_transaction()
                except Exception as e:
                    # Report where the import stopped, then abort instead of
                    # continuing on a transaction that may already have failed.
                    # NOTE(review): the uncommitted transaction is left for the
                    # enclosing session context to discard on exit - confirm
                    # against the driver version in use.
                    print(e, row, reader.line_num)
                    raise

                tx.commit()
                print(total, final_label)

    def import_diagnoses(self, file):
        """Create one ``Diagnosis`` node per row, all in a single transaction.

        Reads the columns ``id`` and ``name``.
        """
        print("importing diagnoses")

        def to_statement(row):
            return (
                "CREATE (diagnosis:Diagnosis {diagnosisId: $diagnosisId, name: $name})",
                {"diagnosisId": row['id'], "name": row['name']},
            )

        self._import_rows(
            file,
            to_statement,
            final_label="lines processed",
            constraints=(
                "CREATE CONSTRAINT ON (a:Diagnosis) ASSERT a.diagnosisId IS UNIQUE; ",
                "CREATE CONSTRAINT ON (a:Diagnosis) ASSERT a.name IS UNIQUE; ",
            ),
            batch_size=None,
        )

    def import_participants(self, file):
        """Create one ``Participant`` node per row.

        Reads the columns ``id``, ``source_id``, ``age``, ``gender`` and
        ``family_profile``.
        """
        print("importing patients")

        def to_statement(row):
            return (
                "CREATE (participant:Participant {participantId: $participantId, "
                "sourceId: $sourceId, age: $age, sex: $sex, familyProfile: $familyProfile})",
                {
                    "participantId": row['id'],
                    "sourceId": row['source_id'],
                    "age": row['age'],
                    "sex": row['gender'],
                    "familyProfile": row['family_profile'],
                },
            )

        self._import_rows(
            file,
            to_statement,
            final_label="lines processed",
            progress_label="participants processed",
            constraints=(
                "CREATE CONSTRAINT ON (a:Participant) ASSERT a.participantId IS UNIQUE; ",
            ),
        )

    def import_diagnosis_participant(self, file):
        """Merge a ``HAS`` relationship from a participant to a diagnosis per row.

        Reads the columns ``patient_id`` and ``diagnosis_id``; both nodes are
        looked up with ``MATCH``, so the statement creates nothing when either
        one is missing.
        """
        print("importing patient-diagnosis relations")

        def to_statement(row):
            return (
                "MATCH (p:Participant {participantId: $participantId}) "
                "MATCH (d:Diagnosis {diagnosisId: $diagnosisId}) "
                "MERGE (p)-[:HAS]->(d)",
                {"participantId": row['patient_id'], "diagnosisId": row['diagnosis_id']},
            )

        self._import_rows(
            file,
            to_statement,
            final_label="lines processed",
            progress_label="participant-diagnosis processed",
        )

    def import_genes(self, file):
        """Create one ``Gene`` node per row.

        Reads the columns ``id``, ``ensembl_id``, ``name``, ``chromosome``,
        ``start``, ``stop``, ``band`` and ``genome_build``.
        """
        print("importing genes")

        def to_statement(row):
            return (
                "CREATE (gene:Gene {geneId: $geneId, ensemblId: $ensemblId, name: $name, "
                "chromosome: $chromosome, start: $start, stop: $stop, band: $band, "
                "genomeBuild: $genomeBuild})",
                {
                    "geneId": row['id'],
                    "ensemblId": row['ensembl_id'],
                    "name": row['name'],
                    "chromosome": row['chromosome'],
                    "start": row['start'],
                    "stop": row['stop'],
                    "band": row['band'],
                    "genomeBuild": row['genome_build'],
                },
            )

        self._import_rows(
            file,
            to_statement,
            final_label="genes imported",
            progress_label="genes imported",
            constraints=(
                "CREATE CONSTRAINT ON (a:Gene) ASSERT a.geneId IS UNIQUE; ",
                "CREATE CONSTRAINT ON (a:Gene) ASSERT a.ensemblId IS UNIQUE; ",
            ),
        )

    def import_participant_gene(self, file):
        """Merge a dosage relationship from a participant to a gene per row.

        Reads the columns ``variation_type_id``, ``inheritance_id``,
        ``patient_id`` and ``gene_id``. The relationship type is chosen by
        ``variation_type_id``.

        Raises:
            ValueError: If a row carries a ``variation_type_id`` that has no
                relationship type assigned.
        """
        def to_statement(row):
            relationship = self._VARIATION_RELATIONSHIPS.get(row['variation_type_id'])
            if relationship is None:
                # Fail with a clear message instead of sending an empty query.
                raise ValueError(
                    "unknown variation_type_id: {!r}".format(row['variation_type_id'])
                )

            # NOTE(review): an unmapped inheritance_id yields None here; Cypher
            # MERGE is believed to reject null property values - confirm how
            # such rows should be handled.
            inheritance = self._INHERITANCES.get(row['inheritance_id'])
            return (
                "MATCH (p:Participant {participantId: $participantId}) "
                "MATCH (g:Gene {geneId: $geneId}) "
                "MERGE (p)-[:" + relationship + "]->(g)",
                {
                    "participantId": row['patient_id'],
                    "geneId": row['gene_id'],
                    "inheritance": inheritance,
                },
            )

        self._import_rows(
            file,
            to_statement,
            final_label="participant-gene imported",
            progress_label="participant-gene imported",
        )

    def import_go(self, file):
        """Create one ``Go`` node per row, with an extra label for its domain.

        Reads the columns ``id``, ``go_id``, ``domain``, ``evidence_code`` and
        ``name``.

        Raises:
            ValueError: If a row carries a ``domain`` that has no label assigned.
        """
        def to_statement(row):
            label = self._GO_DOMAIN_LABELS.get(row['domain'])
            if label is None:
                # Fail with a clear message instead of sending an empty query.
                raise ValueError("unknown GO domain: {!r}".format(row['domain']))

            return (
                "CREATE (go:Go:" + label + " {goAnnotationId: $goAnnotationId, "
                "goId: $goId, domain: $domain, evidenceCode: $evidenceCode, name: $name})",
                {
                    "goAnnotationId": row['id'],
                    "goId": row['go_id'],
                    "domain": row['domain'],
                    "evidenceCode": row['evidence_code'],
                    "name": row['name'],
                },
            )

        self._import_rows(
            file,
            to_statement,
            final_label="goAnnotations imported",
            progress_label="goAnnotations imported",
            constraints=(
                "CREATE CONSTRAINT ON (a:Go) ASSERT a.goId IS UNIQUE; ",
                "CREATE CONSTRAINT ON (a:Go) ASSERT a.goAnnotationId IS UNIQUE; ",
            ),
        )

    def import_gene_go(self, file):
        """Merge a relationship from a gene to a GO annotation per row.

        Reads the columns ``go_annotation_id`` and ``gene_id``. The
        relationship type depends on the ``domain`` stored on the matched
        ``Go`` node: ``PART_OF``, ``ENABLES`` or ``INVOLVED_IN``. The query
        calls ``apoc.do.when``, so it needs the APOC procedures on the server.
        """
        query = """
            MATCH (gene:Gene {geneId: $geneId})
            MATCH (go:Go {goAnnotationId: $goAnnotationId})
            WITH gene, go
            CALL apoc.do.when(
                go.domain = 'cellular_component',
                "MERGE (gene)-[:PART_OF]->(go)",
                "",
                {gene: gene, go:go}
            ) YIELD value
            WITH gene, go
            CALL apoc.do.when(
                go.domain = 'molecular_function',
                "MERGE (gene)-[:ENABLES]->(go)",
                "",
                {gene: gene, go:go}
            ) YIELD value
            WITH gene, go
            CALL apoc.do.when(
                go.domain = 'biological_process',
                "MERGE (gene)-[:INVOLVED_IN]->(go)",
                "",
                {gene: gene, go:go}
            ) YIELD value
            RETURN 1
        """

        def to_statement(row):
            return (
                query,
                {"goAnnotationId": row['go_annotation_id'], "geneId": row['gene_id']},
            )

        self._import_rows(
            file,
            to_statement,
            final_label="gene-goAnnotation imported",
            progress_label="gene-goAnnotation imported",
        )

if __name__ == '__main__':
    # Run the complete import: nodes first, then the relationships that
    # MATCH on those nodes.
    importing = SfariImporter(argv=sys.argv[1:])

    try:
        # Resolved against the current working directory.
        base_path = "../dataset/import/"

        if not os.path.isdir(base_path):
            print(base_path, "isn't a directory")
            sys.exit(1)

        diagnosis_path = os.path.join(base_path, "diagnosis.csv")
        participant_path = os.path.join(base_path, "patient.csv")
        diagnosis_participant_path = os.path.join(base_path, "diagnosispatient.csv")
        gene_path = os.path.join(base_path, "gene.csv")
        participant_gene_path = os.path.join(base_path, "patientgene.csv")
        goannotation_path = os.path.join(base_path, "goannotation.csv")
        gene_goannotation_path = os.path.join(base_path, "genegoannotation.csv")

        start = time.time()
        importing.import_diagnoses(diagnosis_path)
        importing.import_participants(participant_path)
        importing.import_diagnosis_participant(diagnosis_participant_path)
        importing.import_genes(gene_path)
        importing.import_participant_gene(participant_gene_path)
        importing.import_go(goannotation_path)
        importing.import_gene_go(gene_goannotation_path)
        elapsed = time.time() - start
    finally:
        # Always release the importer, including on sys.exit() or a failed import.
        importing.close()

    print("Time to complete:", elapsed)